#!/usr/bin/env python3

"""
    Script to show SFP EEPROM and presence status.
    This script gets the SFP data from the State DB (and the list of ports
    from the APPL DB), unlike sfputil which accesses the transceiver directly.
"""

import ast
import os
import re
import sys
from typing import Dict

import click
from natsort import natsorted
from sonic_py_common.interface import front_panel_prefix, backplane_prefix, inband_prefix, recirc_prefix
from sonic_py_common import multi_asic
from tabulate import tabulate
from utilities_common import multi_asic as multi_asic_util

# Mock the redis DB for unit test purposes.
# When UTILITIES_UNIT_TESTING is "2", the parent of this file's directory and
# its tests/ subdirectory are prepended to sys.path so that the
# mock_tables.dbconnector module can be imported (and is imported here).
# When UTILITIES_UNIT_TESTING_TOPOLOGY is "multi_asic", the multi-ASIC mock is
# imported too and its namespace config is loaded.
# A missing environment variable raises KeyError, which ends this setup early;
# in particular the topology check is skipped if UTILITIES_UNIT_TESTING is unset.
# NOTE(review): if UTILITIES_UNIT_TESTING is set to something other than "2"
# while the topology is "multi_asic", mock_tables is imported without the
# sys.path tweak and mock_tables.dbconnector may not be loaded — confirm this
# combination is never used.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        test_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, test_path)
        import mock_tables.dbconnector
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()
except KeyError:
    pass

# TODO: We should share these maps and the formatting functions between sfputil and sfpshow

# TRANSCEIVER_INFO field name -> display label for the EEPROM section.
# SFPShow.convert_sfp_info_to_output_string() emits the fields sorted by these
# labels (not by field name). 'cable_length' never appears under its own
# label: the 'cable_type' row prints the DB's cable_type value as its label
# and cable_length as its value.
QSFP_DATA_MAP = {
    'model': 'Vendor PN',
    'vendor_oui': 'Vendor OUI',
    'vendor_date': 'Vendor Date Code(YYYY-MM-DD Lot)',
    'manufacturer': 'Vendor Name',
    'vendor_rev': 'Vendor Rev',
    'serial': 'Vendor SN',
    'type': 'Identifier',
    'ext_identifier': 'Extended Identifier',
    'ext_rateselect_compliance': 'Extended RateSelect Compliance',
    'cable_length': 'cable_length',
    'cable_type': 'Length',
    'nominal_bit_rate': 'Nominal Bit Rate(100Mbs)',
    'specification_compliance': 'Specification compliance',
    'encoding': 'Encoding',
    'connector': 'Connector',
    'application_advertisement': 'Application Advertisement'
}

# TRANSCEIVER_DOM_SENSOR field name -> display label maps.

# Single-channel monitor fields, used for the non-QSFP 'MonitorData' section.
SFP_DOM_CHANNEL_MONITOR_MAP = {
    'rx1power': 'RXPower',
    'tx1bias': 'TXBias',
    'tx1power': 'TXPower'
}

# Channel thresholds including TX power; used for non-QSFP modules and for
# QSFP-DD modules.
SFP_DOM_CHANNEL_THRESHOLD_MAP = {
    'txpowerhighalarm':   'TxPowerHighAlarm',
    'txpowerlowalarm':    'TxPowerLowAlarm',
    'txpowerhighwarning': 'TxPowerHighWarning',
    'txpowerlowwarning':  'TxPowerLowWarning',
    'rxpowerhighalarm':   'RxPowerHighAlarm',
    'rxpowerlowalarm':    'RxPowerLowAlarm',
    'rxpowerhighwarning': 'RxPowerHighWarning',
    'rxpowerlowwarning':  'RxPowerLowWarning',
    'txbiashighalarm':    'TxBiasHighAlarm',
    'txbiaslowalarm':     'TxBiasLowAlarm',
    'txbiashighwarning':  'TxBiasHighWarning',
    'txbiaslowwarning':   'TxBiasLowWarning'
}

# Channel thresholds for QSFP (non-DD) modules; same as the SFP set minus the
# TX power thresholds.
QSFP_DOM_CHANNEL_THRESHOLD_MAP = {
    'rxpowerhighalarm':   'RxPowerHighAlarm',
    'rxpowerlowalarm':    'RxPowerLowAlarm',
    'rxpowerhighwarning': 'RxPowerHighWarning',
    'rxpowerlowwarning':  'RxPowerLowWarning',
    'txbiashighalarm':    'TxBiasHighAlarm',
    'txbiaslowalarm':     'TxBiasLowAlarm',
    'txbiashighwarning':  'TxBiasHighWarning',
    'txbiaslowwarning':   'TxBiasLowWarning'
}

# Module-level (temperature / supply voltage) thresholds, shown for every type.
DOM_MODULE_THRESHOLD_MAP = {
    'temphighalarm':  'TempHighAlarm',
    'templowalarm':   'TempLowAlarm',
    'temphighwarning': 'TempHighWarning',
    'templowwarning': 'TempLowWarning',
    'vcchighalarm':   'VccHighAlarm',
    'vcclowalarm':    'VccLowAlarm',
    'vcchighwarning': 'VccHighWarning',
    'vcclowwarning':  'VccLowWarning'
}

# Per-lane monitor fields for 4-lane QSFP modules.
QSFP_DOM_CHANNEL_MONITOR_MAP = {
    'rx1power': 'RX1Power',
    'rx2power': 'RX2Power',
    'rx3power': 'RX3Power',
    'rx4power': 'RX4Power',
    'tx1bias':  'TX1Bias',
    'tx2bias':  'TX2Bias',
    'tx3bias':  'TX3Bias',
    'tx4bias':  'TX4Bias',
    'tx1power': 'TX1Power',
    'tx2power': 'TX2Power',
    'tx3power': 'TX3Power',
    'tx4power': 'TX4Power'
}

# Per-lane monitor fields for 8-lane QSFP-DD modules.
QSFP_DD_DOM_CHANNEL_MONITOR_MAP = {
    'rx1power': 'RX1Power',
    'rx2power': 'RX2Power',
    'rx3power': 'RX3Power',
    'rx4power': 'RX4Power',
    'rx5power': 'RX5Power',
    'rx6power': 'RX6Power',
    'rx7power': 'RX7Power',
    'rx8power': 'RX8Power',
    'tx1bias':  'TX1Bias',
    'tx2bias':  'TX2Bias',
    'tx3bias':  'TX3Bias',
    'tx4bias':  'TX4Bias',
    'tx5bias':  'TX5Bias',
    'tx6bias':  'TX6Bias',
    'tx7bias':  'TX7Bias',
    'tx8bias':  'TX8Bias',
    'tx1power': 'TX1Power',
    'tx2power': 'TX2Power',
    'tx3power': 'TX3Power',
    'tx4power': 'TX4Power',
    'tx5power': 'TX5Power',
    'tx6power': 'TX6Power',
    'tx7power': 'TX7Power',
    'tx8power': 'TX8Power'
}

# Module-level monitor fields, shown for every type.
DOM_MODULE_MONITOR_MAP = {
    'temperature': 'Temperature',
    'voltage': 'Vcc'
}

# Field name -> unit suffix appended to the displayed value (skipped when the
# DB value already ends with the unit or is 'Unknown').

# Units for the channel threshold fields (both SFP and QSFP threshold sets).
DOM_CHANNEL_THRESHOLD_UNIT_MAP = {
    'txpowerhighalarm':   'dBm',
    'txpowerlowalarm':    'dBm',
    'txpowerhighwarning': 'dBm',
    'txpowerlowwarning':  'dBm',
    'rxpowerhighalarm':   'dBm',
    'rxpowerlowalarm':    'dBm',
    'rxpowerhighwarning': 'dBm',
    'rxpowerlowwarning':  'dBm',
    'txbiashighalarm':    'mA',
    'txbiaslowalarm':     'mA',
    'txbiashighwarning':  'mA',
    'txbiaslowwarning':   'mA'
}

# Units for the module temperature / voltage thresholds.
DOM_MODULE_THRESHOLD_UNIT_MAP = {
    'temphighalarm':   'C',
    'templowalarm':    'C',
    'temphighwarning': 'C',
    'templowwarning':  'C',
    'vcchighalarm':    'Volts',
    'vcclowalarm':     'Volts',
    'vcchighwarning':  'Volts',
    'vcclowwarning':   'Volts'
}

# Units for monitor values of lanes 1-4 plus the module monitor fields.
DOM_VALUE_UNIT_MAP = {
    'rx1power': 'dBm',
    'rx2power': 'dBm',
    'rx3power': 'dBm',
    'rx4power': 'dBm',
    'tx1bias': 'mA',
    'tx2bias': 'mA',
    'tx3bias': 'mA',
    'tx4bias': 'mA',
    'tx1power': 'dBm',
    'tx2power': 'dBm',
    'tx3power': 'dBm',
    'tx4power': 'dBm',
    'temperature': 'C',
    'voltage': 'Volts'
}

# Units for monitor values of lanes 1-8 (QSFP-DD) plus the module fields.
QSFP_DD_DOM_VALUE_UNIT_MAP = {
    'rx1power': 'dBm',
    'rx2power': 'dBm',
    'rx3power': 'dBm',
    'rx4power': 'dBm',
    'rx5power': 'dBm',
    'rx6power': 'dBm',
    'rx7power': 'dBm',
    'rx8power': 'dBm',
    'tx1bias': 'mA',
    'tx2bias': 'mA',
    'tx3bias': 'mA',
    'tx4bias': 'mA',
    'tx5bias': 'mA',
    'tx6bias': 'mA',
    'tx7bias': 'mA',
    'tx8bias': 'mA',
    'tx1power': 'dBm',
    'tx2power': 'dBm',
    'tx3power': 'dBm',
    'tx4power': 'dBm',
    'tx5power': 'dBm',
    'tx6power': 'dBm',
    'tx7power': 'dBm',
    'tx8power': 'dBm',
    'temperature': 'C',
    'voltage': 'Volts'
}


def display_invalid_intf_eeprom(intf_name):
    """Echo the '<port>: SFP EEPROM Not detected' message for *intf_name*."""
    click.echo(intf_name + ': SFP EEPROM Not detected\n')


def display_invalid_intf_presence(intf_name):
    """Echo a one-row Port/Presence table that marks *intf_name* as not present."""
    rows = [(intf_name, 'Not present')]
    click.echo(tabulate(rows, ['Port', 'Presence']))


class SFPShow(object):
    """Collect transceiver EEPROM / presence data from the DB and render it.

    ``get_eeprom`` and ``get_presence`` are wrapped by
    ``multi_asic_util.run_on_multi_asic``. They use ``self.db`` without
    assigning it, so the decorator is expected to bind a DB connector for each
    selected namespace before calling them (NOTE(review): that contract lives
    in utilities_common). Results accumulate in ``self.intf_eeprom`` /
    ``self.table``, and the ``display_*`` methods print them.
    """

    def __init__(self, intf_name, namespace_option, dump_dom=False):
        super().__init__()
        # DB connector; None until the multi-ASIC decorator binds one.
        self.db = None
        # Single port to query, or None to walk every front-panel port.
        self.intf_name = intf_name
        # When true, get_eeprom() also renders TRANSCEIVER_DOM_SENSOR data.
        self.dump_dom = dump_dom
        # (port, presence) rows collected by get_presence().
        self.table = []
        # Port name -> rendered EEPROM text collected by get_eeprom().
        self.intf_eeprom: Dict[str, str] = {}
        self.multi_asic = multi_asic_util.MultiAsic(namespace_option=namespace_option)

    def format_dict_value_to_string(self, sorted_key_table,
                                    dom_info_dict, dom_value_map,
                                    dom_unit_map, alignment=0):
        """Render the requested DOM fields as ``<label>: <value><unit>`` lines.

        Args:
            sorted_key_table: DB field names to render, in output order.
            dom_info_dict: field -> value mapping read from the DB (may be None).
            dom_value_map: field -> display label.
            dom_unit_map: field -> unit suffix.
            alignment: labels shorter than this are padded to this width before
                the ": " separator so a section lines up (0 means no padding).

        Returns:
            Newline-terminated lines indented by 16 spaces. Fields that are
            absent or 'N/A' are omitted.
        """
        if not dom_info_dict:
            return ''

        indent = ' ' * 16
        separator = ': '
        lines = []
        for key in sorted_key_table:
            if key not in dom_info_dict:
                continue
            value = dom_info_dict[key]
            if value == 'N/A':
                continue
            label = dom_value_map[key]
            unit = dom_unit_map.get(key, '')
            # Never suffix the 'Unknown' placeholder, and don't double a unit
            # that the stored string already carries.
            if isinstance(value, str) and (value == 'Unknown' or value.endswith(unit)):
                unit = ''
            padded_separator = separator.rjust(len(separator) + alignment - len(label))
            lines.append('{}{}{}{}{}\n'.format(indent, label, padded_separator, value, unit))
        return ''.join(lines)

    @staticmethod
    def _format_spec_compliance(sfp_info_dict, indent):
        """Render the 'Specification compliance' row(s) of the EEPROM section.

        For QSFP-DD modules the stored value is printed verbatim on one line.
        Otherwise it is parsed as a Python literal dict and listed one entry
        per line in natural key order. Input that is malformed or not a dict
        falls back to a single 'N/A' line.
        """
        label = QSFP_DATA_MAP['specification_compliance']
        if sfp_info_dict['type'] == "QSFP-DD Double Density 8X Pluggable Transceiver":
            return '{}{}: {}\n'.format(indent, label, sfp_info_dict['specification_compliance'])

        output = '{}{}:\n'.format(indent, label)
        sub_indent = indent * 2
        try:
            compliance = ast.literal_eval(sfp_info_dict['specification_compliance'])
        except (ValueError, TypeError, SyntaxError):
            # literal_eval raises any of these on malformed input; show N/A
            # instead of aborting the whole command with a traceback.
            compliance = None
        if not isinstance(compliance, dict):
            return output + '{}N/A\n'.format(sub_indent)
        for compliance_key in natsorted(compliance):
            output += '{}{}: {}\n'.format(sub_indent, compliance_key, compliance[compliance_key])
        return output

    def convert_sfp_info_to_output_string(self, sfp_info_dict):
        """Render TRANSCEIVER_INFO fields as the EEPROM section of the output.

        Fields are emitted in alphabetical order of their QSFP_DATA_MAP label.
        The ``cable_type`` row uses the DB's cable_type value as its label and
        cable_length as its value; ``cable_length`` gets no row of its own.

        Raises:
            KeyError: if a field listed in QSFP_DATA_MAP is missing from
                *sfp_info_dict*.
        """
        indent = ' ' * 8
        output = ''

        for key in sorted(QSFP_DATA_MAP, key=QSFP_DATA_MAP.get):
            if key == 'cable_length':
                # Printed as the value of the 'cable_type' row instead.
                continue
            if key == 'cable_type':
                output += '{}{}: {}\n'.format(indent, sfp_info_dict['cable_type'], sfp_info_dict['cable_length'])
            elif key == 'specification_compliance':
                output += self._format_spec_compliance(sfp_info_dict, indent)
            else:
                output += '{}{}: {}\n'.format(indent, QSFP_DATA_MAP[key], sfp_info_dict[key])

        return output

    def convert_dom_to_output_string(self, sfp_type, dom_info_dict):
        """Render TRANSCEIVER_DOM_SENSOR data as the DOM section of the output.

        Args:
            sfp_type: the TRANSCEIVER_INFO 'type' string. A 'QSFP' prefix selects
                the QSFP layout (8 lanes when it starts with 'QSFP-DD', otherwise
                4). Any other value uses the single-channel SFP layout.
            dom_info_dict: field -> value mapping from TRANSCEIVER_DOM_SENSOR.

        Returns:
            The section text. Section headers are printed even when none of
            their fields are available.
        """
        indent = ' ' * 8
        channel_threshold_align = 18
        module_threshold_align = 15

        # Each section: (header line or None, field->label map, field->unit map, alignment).
        if sfp_type.startswith('QSFP'):
            if sfp_type.startswith('QSFP-DD'):
                channel_monitor_map = QSFP_DD_DOM_CHANNEL_MONITOR_MAP
                channel_monitor_units = QSFP_DD_DOM_VALUE_UNIT_MAP
                # QSFP-DD shows the threshold set that includes TX power.
                channel_threshold_map = SFP_DOM_CHANNEL_THRESHOLD_MAP
            else:
                channel_monitor_map = QSFP_DOM_CHANNEL_MONITOR_MAP
                channel_monitor_units = DOM_VALUE_UNIT_MAP
                channel_threshold_map = QSFP_DOM_CHANNEL_THRESHOLD_MAP
            sections = [
                ('ChannelMonitorValues:\n', channel_monitor_map, channel_monitor_units, 0),
                ('ChannelThresholdValues:\n', channel_threshold_map,
                 DOM_CHANNEL_THRESHOLD_UNIT_MAP, channel_threshold_align),
                ('ModuleMonitorValues:\n', DOM_MODULE_MONITOR_MAP, DOM_VALUE_UNIT_MAP, 0),
                ('ModuleThresholdValues:\n', DOM_MODULE_THRESHOLD_MAP,
                 DOM_MODULE_THRESHOLD_UNIT_MAP, module_threshold_align),
            ]
        else:
            sections = [
                ('MonitorData:\n', SFP_DOM_CHANNEL_MONITOR_MAP, DOM_VALUE_UNIT_MAP, 0),
                (None, DOM_MODULE_MONITOR_MAP, DOM_VALUE_UNIT_MAP, 0),
                ('ThresholdData:\n', DOM_MODULE_THRESHOLD_MAP,
                 DOM_MODULE_THRESHOLD_UNIT_MAP, module_threshold_align),
                (None, SFP_DOM_CHANNEL_THRESHOLD_MAP,
                 DOM_CHANNEL_THRESHOLD_UNIT_MAP, channel_threshold_align),
            ]

        output_dom = ''
        for header, value_map, unit_map, alignment in sections:
            if header is not None:
                output_dom += indent + header
            output_dom += self.format_dict_value_to_string(
                natsorted(value_map), dom_info_dict, value_map, unit_map, alignment)
        return output_dom

    def convert_interface_sfp_info_to_cli_output_string(self, state_db, interface_name, dump_dom):
        """Build the EEPROM text (plus DOM data if *dump_dom*) for one port.

        Reads ``TRANSCEIVER_INFO|<port>`` from STATE_DB through *state_db*, and
        also ``TRANSCEIVER_DOM_SENSOR|<port>`` when *dump_dom* is true.
        """
        sfp_info_dict = state_db.get_all(state_db.STATE_DB, 'TRANSCEIVER_INFO|{}'.format(interface_name))
        output = 'SFP EEPROM detected\n' + self.convert_sfp_info_to_output_string(sfp_info_dict)

        if dump_dom:
            sfp_type = sfp_info_dict['type']
            dom_info_dict = state_db.get_all(state_db.STATE_DB, 'TRANSCEIVER_DOM_SENSOR|{}'.format(interface_name))
            output += self.convert_dom_to_output_string(sfp_type, dom_info_dict)

        return output

    def _front_panel_ports(self):
        """Yield front-panel port names taken from APPL_DB ``PORT_TABLE:*`` keys.

        Backplane, inband and recirculation ports are skipped. This uses
        ``self.db``, so call it only from a run_on_multi_asic-wrapped method.
        """
        front_panel = front_panel_prefix()
        excluded = (backplane_prefix(), inband_prefix(), recirc_prefix())
        # `or []` guards against a connector returning None when nothing matches.
        for table_key in self.db.keys(self.db.APPL_DB, "PORT_TABLE:*") or []:
            port = re.split(':', table_key, maxsplit=1)[-1].strip()
            if port and port.startswith(front_panel) and not port.startswith(excluded):
                yield port

    @multi_asic_util.run_on_multi_asic
    def get_eeprom(self):
        """Store rendered EEPROM text per port in ``self.intf_eeprom``.

        Covers the requested port, or every front-panel port when no port
        was given.
        """
        ports = [self.intf_name] if self.intf_name is not None else self._front_panel_ports()
        for port in ports:
            if self.db.exists(self.db.STATE_DB, 'TRANSCEIVER_INFO|{}'.format(port)):
                self.intf_eeprom[port] = self.convert_interface_sfp_info_to_cli_output_string(
                    self.db, port, self.dump_dom)
            else:
                self.intf_eeprom[port] = "SFP EEPROM Not detected\n"

    @multi_asic_util.run_on_multi_asic
    def get_presence(self):
        """Append one (port, 'Present' / 'Not present') row per port to ``self.table``."""
        ports = [self.intf_name] if self.intf_name is not None else self._front_panel_ports()
        for port in ports:
            present = self.db.exists(self.db.STATE_DB, 'TRANSCEIVER_INFO|{}'.format(port))
            self.table.append((port, 'Present' if present else 'Not present'))

    def display_eeprom(self):
        """Echo each collected port as ``<port>: <text>`` in natural port order."""
        click.echo("\n".join([f"{k}: {v}" for k, v in natsorted(self.intf_eeprom.items())]))

    def display_presence(self):
        """Echo the collected presence rows as a Port/Presence table in natural order."""
        header = ['Port', 'Presence']
        sorted_port_table = natsorted(self.table)
        click.echo(tabulate(sorted_port_table, header))

# This is our main entrypoint - the main 'sfpshow' command


@click.group()
def cli():
    """sfpshow - Command line utility for display SFP transceivers information"""
    # The group itself does nothing; the eeprom/presence subcommands do the work.

# 'eeprom' subcommand


@cli.command()
@click.option('-p', '--port', metavar='<port_name>', help="Display SFP EEPROM data for port <port_name> only")
@click.option('-d', '--dom', 'dump_dom', is_flag=True, help="Also display Digital Optical Monitoring (DOM) data")
@click.option('-n', '--namespace', default=None, help="Display interfaces for specific namespace")
def eeprom(port, dump_dom, namespace):
    # Show EEPROM (and, with --dom, DOM) data for one port or every front-panel
    # port. On multi-ASIC systems a --port given without --namespace is first
    # mapped to its namespace; a port that cannot be mapped is reported as
    # "not detected" and the command exits with status 1.
    resolve_namespace = bool(port) and multi_asic.is_multi_asic() and namespace is None
    if resolve_namespace:
        try:
            namespace = multi_asic.get_namespace_for_port(port)
        except Exception:
            display_invalid_intf_eeprom(port)
            sys.exit(1)

    show = SFPShow(port, namespace, dump_dom)
    show.get_eeprom()
    show.display_eeprom()

# 'presence' subcommand


@cli.command()
@click.option('-p', '--port', metavar='<port_name>', help="Display SFP presence for port <port_name> only")
@click.option('-n', '--namespace', default=None, help="Display interfaces for specific namespace")
def presence(port, namespace):
    # Show Present/Not present for one port or every front-panel port. On
    # multi-ASIC systems a --port given without --namespace is first mapped to
    # its namespace; a port that cannot be mapped is reported as not present
    # and the command exits with status 1.
    resolve_namespace = bool(port) and multi_asic.is_multi_asic() and namespace is None
    if resolve_namespace:
        try:
            namespace = multi_asic.get_namespace_for_port(port)
        except Exception:
            display_invalid_intf_presence(port)
            sys.exit(1)

    show = SFPShow(port, namespace)
    show.get_presence()
    show.display_presence()


if __name__ == "__main__":
    # Executed as a script: hand control to the click command group.
    cli()
